import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt


def simple_evaluate_OD_classification(classifier, regressor, dataset):
    """Measure classification accuracy at the two locations proposed by a regressor.

    For each batch the regressor is called once and yields a left and a right
    location. The classifier is evaluated at each location, and the arg-max over
    its last axis is compared with the ground-truth labels ``N1`` (left) and
    ``N2`` (right).

    Args:
        classifier: callable ``classifier(images, reg_x, reg_y)`` returning
            per-class scores (arg-max is taken over the last axis).
        regressor: callable ``regressor(images)`` returning
            ``(left_x, left_y, right_x, right_y)``.
        dataset: iterable of ``(inputs, targets)`` pairs, where ``inputs``
            unpacks to exactly ``(I, N1, N2, X1, X2, Y1, Y2)``. ``targets`` is
            ignored.

    Returns:
        Tuple ``(left_accuracy, right_accuracy)`` as Python floats.
    """
    hits_left = 0
    hits_right = 0
    n_seen = 0
    for inputs, _ in dataset:
        images, labels_left, labels_right, _x1, _x2, _y1, _y2 = inputs
        left_x, left_y, right_x, right_y = regressor(images)

        # Classify at the left proposal first, then at the right one (same call order as before).
        pred_left = tf.argmax(classifier(images, left_x, left_y), axis=-1)
        pred_right = tf.argmax(classifier(images, right_x, right_y), axis=-1)

        hits_left += tf.reduce_sum(tf.where((pred_left == labels_left), 1, 0))
        hits_right += tf.reduce_sum(tf.where((pred_right == labels_right), 1, 0))
        n_seen += images.shape[0]

    return float(hits_left / n_seen), float(hits_right / n_seen)


def simple_evaluate_ODbaseline(network, dataset):
    """Accuracy of a baseline network that directly predicts the label difference N1 - N2.

    The network's arg-max class index is shifted by -9 to obtain the predicted
    difference. With 19 output classes this presumably covers -9..9; confirm
    against the model definition.

    Args:
        network: model exposing ``call(images)``, which returns per-class scores.
        dataset: iterable of ``(inputs, targets)`` pairs, where ``inputs``
            unpacks to exactly ``(I, N1, N2, X1, X2, Y1, Y2)``. ``targets`` is
            ignored; the target difference is recomputed from ``N1`` and ``N2``.

    Returns:
        Fraction of samples whose predicted difference matches ``N1 - N2``, as a float.
    """
    correct = 0
    total = 0
    for batch_inputs, _ in dataset:
        images, first_label, second_label, _x1, _x2, _y1, _y2 = batch_inputs
        expected_diff = tf.subtract(first_label, second_label)
        # Class index k of the network output stands for the difference k - 9.
        predicted_diff = tf.argmax(network.call(images), axis=-1) - 9
        correct += tf.reduce_sum(tf.where((expected_diff == predicted_diff), 1, 0))
        total += images.shape[0]
    return float(correct / total)

def plot_simple_masked(regressnet, data, generalised=False):
    """Display the first image of a batch masked to each of the two regressed boxes.

    ``regressnet(images)`` returns ``(P1x, P1y, P2x, P2y)``. In each tensor,
    column 0 is a centre ``mu`` and column 1 a spread ``sigma``, both in
    coordinates normalised to [0, 1]. For each of the two objects, the box
    ``mu ± mult * sigma`` (``mult`` is 1 if ``generalised`` else 2) is clipped to
    [0, 1] and converted to integer pixel indices by truncation. Every pixel
    outside the box is zeroed, and the first masked image is shown with
    matplotlib. One figure is shown per object.

    Args:
        regressnet: callable producing the (mu, sigma) predictions described above.
        data: indexable whose element 0 is the image batch. This presumably has
            shape (batch, height, width, channels); TODO confirm.
        generalised: if True use a 1-sigma box, otherwise a 2-sigma box.
    """
    images = data[0]
    mult = 1 if generalised else 2
    P1x, P1y, P2x, P2y = regressnet(images)

    # Pixel-index grids derived from the image, previously hard-coded to 128.
    # y coordinates index rows (height); x coordinates index columns (width).
    height, width = images.shape[1], images.shape[2]
    row_idx = tf.range(height)[tf.newaxis, :]  # (1, H)
    col_idx = tf.range(width)[tf.newaxis, :]   # (1, W)

    # A distinct loop variable: the original inner loop re-used the outer `i`.
    for px, py in ((P1x, P1y), (P2x, P2y)):
        mux, sigmax = px[:, 0], px[:, 1]
        muy, sigmay = py[:, 0], py[:, 1]

        x1 = tf.math.maximum(mux - mult * sigmax, 0.)
        x2 = tf.math.minimum(mux + mult * sigmax, 1.)
        y1 = tf.math.maximum(muy - mult * sigmay, 0.)
        y2 = tf.math.minimum(muy + mult * sigmay, 1.)

        # Normalised -> integer pixel index (truncating cast), shaped (B, 1) for broadcasting.
        x1 = tf.cast(x1 * width, tf.int32)[:, tf.newaxis]
        x2 = tf.cast(x2 * width, tf.int32)[:, tf.newaxis]
        y1 = tf.cast(y1 * height, tf.int32)[:, tf.newaxis]
        y2 = tf.cast(y2 * height, tf.int32)[:, tf.newaxis]

        # Vectorised inclusive range tests; this replaces the per-pixel Python loop.
        in_rows = tf.cast(tf.logical_and(y1 <= row_idx, row_idx <= y2), tf.float32)  # (B, H)
        in_cols = tf.cast(tf.logical_and(x1 <= col_idx, col_idx <= x2), tf.float32)  # (B, W)
        masky = in_rows[:, :, tf.newaxis, tf.newaxis]  # (B, H, 1, 1)
        maskx = in_cols[:, tf.newaxis, :, tf.newaxis]  # (B, 1, W, 1)
        masked_images = images * maskx * masky

        plt.imshow(masked_images[0])
        plt.show()

def simple_complete_evaluation(network, dataset):
    """Accuracy of choosing the label difference by scoring all 19 candidate offsets.

    For every sample, the image is repeated 19 times and fed to
    ``network.call`` together with the candidate offsets -9..9 and the repeated
    Euclidean distance between ``(X1, Y1)`` and ``(X2, Y2)``. The network output
    is averaged over its last axis to give one score per candidate. The
    arg-max candidate (shifted by -9) is counted as correct when it equals
    ``N1 - N2``.

    Args:
        network: model exposing ``call([images, offsets, distances])``.
        dataset: iterable of ``(inputs, targets)`` pairs, where ``inputs``
            unpacks to exactly ``(I, N1, N2, X1, X2, Y1, Y2)``. ``targets`` is
            ignored.

    Returns:
        Fraction of correctly predicted samples, as a float.
    """
    n_candidates = 19
    # Candidate offsets -9..9; the tensor is constant, so build it once.
    offsets = tf.stack([k - 9 for k in range(n_candidates)])

    n_correct = 0
    n_total = 0
    for batch, _ in dataset:
        I, N1, N2, X1, X2, Y1, Y2 = batch
        distances = tf.sqrt(tf.square(X1 - X2) + tf.square(Y1 - Y2))
        for sample_idx, image in enumerate(I):
            repeated_images = tf.stack([image] * n_candidates)
            repeated_dist = tf.stack([distances[sample_idx]] * n_candidates)
            scores = network.call([repeated_images, offsets, repeated_dist])
            best = tf.argmax(tf.reduce_mean(scores, axis=-1))
            if int(best - 9) == int(N1[sample_idx] - N2[sample_idx]):
                n_correct += 1
            n_total += 1
    return n_correct / n_total

def _iou_gaussian_box(mu_x, sigma_x, mu_y, sigma_y, mult):
    """Return the ``[x1, y1, x2, y2]`` box ``mu ± mult * sigma``, clipped to [0, 1]."""
    x1 = tf.maximum(mu_x - mult * sigma_x, 0.)
    y1 = tf.maximum(mu_y - mult * sigma_y, 0.)
    x2 = tf.minimum(mu_x + mult * sigma_x, 1.)
    y2 = tf.minimum(mu_y + mult * sigma_y, 1.)
    return tf.stack([x1, y1, x2, y2], axis=-1)


def _iou_centred_box(cx, cy, half_size):
    """Return the ``[x1, y1, x2, y2]`` square of half-side ``half_size`` centred at (cx, cy)."""
    return tf.stack([cx - half_size, cy - half_size, cx + half_size, cy + half_size], axis=-1)


def simple_evaluate_regression_IoU(regressor, dataset, mean=True, generalised=False,
                                   target_half_size=14 / 128):
    """Compute the IoU between regressed boxes and ground-truth boxes for both objects.

    ``regressor(I)`` returns ``(P1x, P1y, P2x, P2y)``. In each tensor, column 0
    is a centre ``mu`` and column 1 a spread ``sigma``, in normalised
    coordinates. Predicted boxes are ``mu ± mult * sigma`` (``mult`` is 1 if
    ``generalised`` else 2), clipped to [0, 1]. Ground-truth boxes are squares
    of half-side ``target_half_size`` centred at ``(X1, Y1)`` and ``(X2, Y2)``.

    Each object now uses its own sigma. Previously the second predicted box
    reused the first object's sigma.

    Args:
        regressor: callable returning the (mu, sigma) predictions described above.
        dataset: iterable of ``(inputs, targets)`` pairs, where ``inputs``
            unpacks to exactly ``(I, N1, N2, X1, X2, Y1, Y2)``. ``targets`` is
            ignored.
        mean: if True, return the mean IoU over all batches for each object.
            Otherwise return the stacked per-batch (squeezed) IoU tensors.
        generalised: if True use a 1-sigma box, otherwise a 2-sigma box.
        target_half_size: half-side of the ground-truth square in normalised
            units. The default 14/128 is presumably half of a 28-px object in a
            128-px image; confirm.

    Returns:
        ``(iou_object1, iou_object2)``. These are scalars from ``np.mean`` when
        ``mean`` is True, otherwise tensors from ``tf.stack(..., axis=-1)``.
        ``bbox_iou`` is defined elsewhere and receives ``[x1, y1, x2, y2]`` boxes.
    """
    iou1, iou2 = [], []
    mult = 1 if generalised else 2

    for inputs, _ in dataset:
        I, N1, N2, X1, X2, Y1, Y2 = inputs
        P1x, P1y, P2x, P2y = regressor(I)

        target_box1 = _iou_centred_box(X1, Y1, target_half_size)
        target_box2 = _iou_centred_box(X2, Y2, target_half_size)
        # Each predicted box uses its own object's mu AND sigma.
        pred_box1 = _iou_gaussian_box(P1x[:, 0], P1x[:, 1], P1y[:, 0], P1y[:, 1], mult)
        pred_box2 = _iou_gaussian_box(P2x[:, 0], P2x[:, 1], P2y[:, 0], P2y[:, 1], mult)

        batch_iou1 = bbox_iou(target_box1, pred_box1)
        batch_iou2 = bbox_iou(target_box2, pred_box2)
        if mean:
            iou1.append(tf.reduce_mean(batch_iou1))
            iou2.append(tf.reduce_mean(batch_iou2))
        else:
            iou1.append(tf.squeeze(batch_iou1))
            iou2.append(tf.squeeze(batch_iou2))

    if mean:
        return np.mean(iou1), np.mean(iou2)
    return tf.stack(iou1, axis=-1), tf.stack(iou2, axis=-1)
